
Fix Flash Attention 3 API compatibility for window size parameters#2704

Open
jhvmhg wants to merge 10 commits into NVIDIA:main from jhvmhg:fix/flash_attn3_support_CP

Conversation


@jhvmhg jhvmhg commented Feb 25, 2026

Replace single window_size parameter with window_size_left and window_size_right in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.

  • Update function signature in flash_attn_interface
  • Maintain backward compatibility where possible
  • Ensure consistency with Flash Attention v2 implementation

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

  1. Fix window size parameters in flash_attn_fwd - Replaces the single window_size parameter with separate window_size_left and window_size_right parameters to match the updated flash-attn v2.7.0+ API.
  2. Fix causal parameter naming in flash_attn_bwd - Renames causal to is_causal in the backward function signature for consistency with the latest flash-attn interface.

Motivation:

The flash-attn library v2.7.0+ introduced breaking API changes that cause compatibility issues with TransformerEngine's Flash Attention 3 integration. These updates ensure seamless operation with newer versions of the flash-attn library while maintaining correctness of both forward and backward attention computations.

Related API Changes:

flash-attn v2.7.0+ split window_size into window_size_left and window_size_right
flash-attn v3+ renamed causal parameter to is_causal in backward pass
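The window-size split described above can be sketched as a small version-gated mapping. This is an illustrative sketch, not TransformerEngine code: `build_window_kwargs` is an invented helper name, and `v2_7_0_plus` stands in for the version flag TransformerEngine keeps in `fa_utils`.

```python
def build_window_kwargs(window_size, v2_7_0_plus):
    """Map a (left, right) window tuple onto the kwargs the installed
    flash-attn version expects (hypothetical helper for illustration)."""
    left, right = window_size
    if v2_7_0_plus:
        # flash-attn v2.7.0+ takes the two halves as separate parameters
        return {"window_size_left": left, "window_size_right": right}
    # older releases take the single tuple
    return {"window_size": (left, right)}

print(build_window_kwargs((-1, -1), True))   # split parameters
print(build_window_kwargs((-1, -1), False))  # legacy tuple
```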

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Replace single window_size parameter with window_size_left and window_size_right
    in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.
  • Rename causal parameter to is_causal in flash_attn_bwd function to align
    with flash-attn v3

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Replace single window_size parameter with window_size_left and window_size_right
in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.

- Update function signature in flash_attn_interface
- Maintain backward compatibility where possible
- Ensure consistency with Flash Attention v2 implementation

Signed-off-by: Chaoyang Mei <1192554423@qq.com>
Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>

greptile-apps bot commented Feb 25, 2026

Greptile Summary

This PR fixes Flash Attention 3 API incompatibilities in the context-parallel attention path by splitting the legacy window_size tuple parameter into separate window_size_left / window_size_right parameters for FA3 and flash-attn v2.7+, and renaming the backward-pass causal keyword to is_causal for FA3.

Key Finding:
A condition-priority bug remains in the shared helper functions cp_p2p_fwd_flash_attn, cp_p2p_bwd_flash_attn (at two locations), and AttnFuncWithCPAndKVAllGather.backward. The fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus condition is checked before use_flash_attn_3, causing incorrect parameter passing when FA3 is co-installed with flash-attn v2.3–v2.6. In that scenario, the code sets the legacy window_size tuple, but the actual function called is the FA3 version (_flash_attn_fwd_v3/_flash_attn_bwd_v3), which expects window_size_left/window_size_right, resulting in a runtime error.

Impact:

  • Safe to merge for the common deployments (FA3-only, or FA3 with flash-attn v2.7+); both scenarios work correctly
  • Will fail at runtime if FA3 is co-installed with flash-attn v2.3–v2.6 (an uncommon but valid configuration)

The core logic is sound for mainstream deployments, but the condition-ordering issue must be fixed before the PR works across all supported version combinations.
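The ordering issue can be reproduced in a minimal sketch. The flag names mirror the review (`use_flash_attn_3`, `v2_3_plus`, `v2_7_0_plus`); the function itself is invented for illustration and shows the corrected priority, with FA3 checked before the legacy v2.3 branch:

```python
def select_window_kwargs(use_flash_attn_3, v2_3_plus, v2_7_0_plus):
    """Pick window-size kwargs with FA3 taking priority (sketch only)."""
    kwargs = {}
    # FA3 and flash-attn v2.7+ both expect the split parameters, so this
    # branch must come first even when an older v2.x is co-installed.
    if use_flash_attn_3 or v2_7_0_plus:
        kwargs["window_size_left"] = -1
        kwargs["window_size_right"] = -1
    elif v2_3_plus:
        kwargs["window_size"] = (-1, -1)
    return kwargs

# FA3 co-installed with flash-attn v2.5: the split parameters still win.
print(select_window_kwargs(True, True, False))
```

With the original ordering (`v2_3_plus and not v2_7_0_plus` checked first), this same configuration would fall into the legacy-tuple branch and then crash inside the FA3 entry point, which does not accept the tuple form.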

Confidence Score: 2/5

  • Safe for FA3-only or FA3+v2.7+ deployments, but will fail at runtime if FA3 is co-installed with flash-attn v2.3–v2.6 due to condition-priority issue with window-size parameter passing.
  • The core API compatibility fix is correct and handles the two most common deployments (FA3-only and FA3+v2.7+). However, a condition-priority bug causes the code to pass legacy window_size tuple parameters to FA3 functions when flash-attn v2.3–v2.6 is co-installed, resulting in runtime errors. Until the condition order is fixed in the affected helper functions, the PR cannot be considered fully safe across all supported configurations.
  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py — specifically the condition chains in cp_p2p_fwd_flash_attn (line 940), cp_p2p_bwd_flash_attn (lines 1192, 1204), and AttnFuncWithCPAndKVAllGather.backward (line 3213).

Last reviewed commit: 01b4539


@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment



greptile-apps bot commented Feb 25, 2026

Additional Comments (1)

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
Removed the causal parameter, but other flash_attn_bwd calls in this file (lines 3222 and 3832) still pass it; verify that this inconsistency is intentional.

Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v3 API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@jhvmhg jhvmhg force-pushed the fix/flash_attn3_support_CP branch from a245229 to f9752ca on February 25, 2026 at 07:54

@jhvmhg jhvmhg left a comment


Fix Flash Attention 3 backward API parameter naming

Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v3 API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.


@greptile-apps greptile-apps bot left a comment


1 file reviewed, no comments


jhvmhg and others added 2 commits February 25, 2026 15:56
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v3 API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>

@greptile-apps greptile-apps bot left a comment


1 file reviewed, no comments


Replace keyword arguments with positional arguments in flash_attn_fwd and
flash_attn_bwd to abstract away parameter naming differences (causal vs
is_causal) between flash-attn versions. This provides a more robust
interface that is resilient to future API changes in the flash-attn library.

- Convert window_size_left, window_size_right, and causal parameters to
  positional args in both forward and backward functions
- Eliminate version-specific parameter naming dependencies
- Simplify compatibility handling across flash-attn v2.7.0+ variants

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
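The idea in the commit message above can be sketched as follows. The two functions are stand-ins, not the real flash-attn entry points; they only model the v2/v3 difference in how the causal flag is spelled:

```python
def flash_attn_bwd_v2(dout, causal):
    """Stand-in for a v2-style backward: flag is named `causal`."""
    return ("v2", causal)

def flash_attn_bwd_v3(dout, is_causal):
    """Stand-in for a v3-style backward: flag is named `is_causal`."""
    return ("v3", is_causal)

# A keyword call ties the call site to one spelling:
#   flash_attn_bwd_v3("dout", causal=True)  -> TypeError
# A positional call works against either signature:
for fn in (flash_attn_bwd_v2, flash_attn_bwd_v3):
    print(fn("dout", True))
```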

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment


        softmax_lse_per_step[i],
        *fa_backward_args_thd,
    -   causal="causal" in ctx.attn_mask_type,
    +   ctx.attn_mask_type,

ctx.attn_mask_type is a string (e.g., "causal", "no_mask"), not a boolean. Should be "causal" in ctx.attn_mask_type to convert to boolean.

Suggested change:

    -   ctx.attn_mask_type,
    +   "causal" in ctx.attn_mask_type,

@jhvmhg jhvmhg closed this Feb 25, 2026
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v3 API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@jhvmhg jhvmhg reopened this Feb 25, 2026
@jhvmhg jhvmhg closed this Feb 25, 2026
@jhvmhg jhvmhg reopened this Feb 25, 2026

@greptile-apps greptile-apps bot left a comment


1 file reviewed, no comments


@ptrendx ptrendx requested a review from cyanguwa February 25, 2026 19:57
@cyanguwa cyanguwa requested a review from mk-61 February 26, 2026 00:04
@cyanguwa
Collaborator

@mk-61 I think the changes look good, but could you please follow through with the CI, especially the L3_FA_version tests, to make sure the new changes pass the SWA tests for FA3? Thanks!

@vcherepanov-nv
Collaborator

LGTM

@sudhakarsingh27
Collaborator

/te-ci pytorch L3


greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 944
Condition-priority bug for FA3 + flash-attn v2.3–v2.6

When use_flash_attn_3=True AND flash-attn v2 is installed between versions 2.3 and 2.6 (fa_utils.v2_3_plus=True, fa_utils.v2_7_0_plus=False), the first condition fires and sets the legacy window_size tuple parameter. However, the actual function called is _flash_attn_fwd_v3, which requires window_size_left and window_size_right parameters, resulting in a runtime error.

The use_flash_attn_3 check should take priority to ensure FA3 always receives the new API regardless of which flash-attn v2 version is co-installed.

Suggested fix for this location and similar ones at lines 1192–1208 and 3213–3217:

if use_flash_attn_3 or fa_utils.v2_7_0_plus:
    fa_forward_kwargs["window_size_left"] = -1
    fa_forward_kwargs["window_size_right"] = -1
elif fa_utils.v2_3_plus:
    fa_forward_kwargs["window_size"] = (-1, -1)
